
[AMD] Enable preshuffle paged MQA and page_size=64 for NSA indexer#23562

Open
1am9trash wants to merge 12 commits intosgl-project:mainfrom
1am9trash:preshuffle-indexer-kvblocksize16

Conversation


@1am9trash (Collaborator) commented Apr 23, 2026

Depends on aiter PR#2879 — must be merged first.

Motivation

In the NSA indexer, the paged MQA computation is very slow at high concurrency. For example, with 8k input / 1k output at concurrency 64 (8k1k cc64), it takes ~88 us per layer across two kernels (logits tensor init: 11 us, MQA kernel: 77 us).


The main bottleneck is the oversized block_table, whose shape is (batch_size, max_seq_len / page_size). With page_size=1 and max_seq_len=131k at concurrency 64, the block_table shape is (64, 131072), totaling 64 × 131072 × 4 B ≈ 32 MB; the frequent indirect loads through this table make the MQA kernel slow.
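A quick back-of-envelope check of these sizes (pure-Python sketch; int32 entries and ceil division for the page count are assumptions consistent with the shapes above):

```python
# Size of the block_table in bytes: one int32 entry per page per request.
def block_table_bytes(batch_size, max_seq_len, page_size, entry_bytes=4):
    num_pages = -(-max_seq_len // page_size)  # ceil division
    return batch_size * num_pages * entry_bytes

MB = 1024 * 1024
print(block_table_bytes(64, 131072, page_size=1) / MB)   # 32.0 MB
print(block_table_bytes(64, 131072, page_size=64) / MB)  # 0.5 MB
```

Moving from page_size=1 to page_size=64 shrinks the table 64×, which is the reduction this PR targets.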

Modifications

To reduce the block_table size, we cannot use page_size=1. In this PR, we change it to 64. This introduces the following changes:

  • The original MQA kernel does not support page_size > 1. We switch to the MQA preshuffle kernel, which supports any page_size that is a multiple of 16.
  • Since the MQA preshuffle kernel requires a specific k-cache layout, the indexer's k-cache read/write paths must also store data accordingly. This was not supported in the existing kernels. In aiter PR#2879, we add preshuffle support to the indexer k-cache read/write kernels (controlled by preshuffle=True).

Other changes:

  • Replace torch.full(..., -inf) with torch.empty(...) to eliminate an unnecessary initialization kernel.
  • We also update the page_size initialization and the corresponding assertions.
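The torch.full → torch.empty change can be shown in isolation (shapes are illustrative; skipping the init is only safe because the MQA kernel overwrites every element before it is read):

```python
import torch

batch, max_kv = 64, 1024  # illustrative logits-buffer shape

# Before: torch.full dispatches an extra fill kernel to write -inf everywhere.
logits_old = torch.full((batch, max_kv), float("-inf"))

# After: torch.empty only allocates -- no init kernel is launched -- because
# the MQA kernel is guaranteed to write every element before any read.
logits_new = torch.empty(batch, max_kv)

assert logits_old.shape == logits_new.shape
```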

Accuracy Tests

GLM-5.1-FP8 launch command:
export SGLANG_ROCM_FUSED_DECODE_MLA=0
export ROCM_QUICK_REDUCE_QUANTIZATION=INT4
export SAFETENSORS_FAST_GPU=1
python3 -m sglang.launch_server \
  --model-path GLM-5.1-FP8 \
  --tp 8 --port 9000 --trust-remote-code \
  --tool-call-parser glm47 --reasoning-parser glm45 \
  --mem-fraction-static 0.85 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --nsa-prefill-backend tilelang --nsa-decode-backend tilelang --disable-radix-cache \
  --kv-cache-dtype fp8_e4m3

MI355 GSM8k (TP8): 0.951

Speed Tests and Profiling

Per-layer profiling:

  • Kernels: 2 -> 1

Critical case (8k1k cc64):

  • Time per layer: 88 us -> 19 us
  • E2E throughput: +16%

Benchmark on MI355X TP8, averaged over concurrency 4/8/16/32/64:

  • ISL/OSL 1k/1k: Throughput +3.83%, TPOT -3.26%
  • ISL/OSL 8k/1k: Throughput +5.61%, TPOT -6.65%

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a 'preshuffle' layout for DeepSeek DSA on HIP, integrating AITER-based kernels for gathering and storing K/S data and updating the default page size to 64. A significant issue was identified where falling back to Triton kernels on HIP (when AITER is disabled) would result in a layout mismatch, as these kernels do not yet support the preshuffle layout required by the MQA kernel.

Comment thread: python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 29, 2026
Squashed implementation of three HIP-only optimizations that together
shrink the GLM-5-FP8 NSA tilelang decode layer on MI355X from ~397 us
to ~324 us (-73 us / -18.4%, MI355X TP=8 fp8 KV cache).

==============================================================================
1. fix(rocm): restore `_is_hip` in DeepseekV2Model.alt_stream creation
==============================================================================

Commit a1ceb2e ("[AMD] Enable MoE dual stream overlap on HIP for GLM4/GLM5")
added `_is_hip` to the alt_stream gate. The MUSA backend PR b35213b
("[MUSA][16/N] Add MUSA backend support for layers and DeepSeek models")
was branched off a parent that did not contain a1ceb2e, and on merge
inadvertently dropped `_is_hip` while adding `_is_musa`. Result: on ROCm
`self.alt_stream is None`, so `forward_normal_dual_stream` and the MLA
dual-stream fork are never entered — decode traces show only one
physical stream.

This commit restores `_is_hip` alongside `_is_musa` and re-applies the
`not _use_aiter` guard in `forward_normal_dual_stream`'s
routed_scaling_factor multiply (aiter's biased_grouped_topk already
fuses the scaling, so multiplying again would double it).

Both changes are HIP-only: CUDA / MUSA / NPU branches are unaffected.

==============================================================================
2. perf(rocm-nsa): A_v4 dual-stream layout in forward_absorb_prepare
==============================================================================

Refactor the q_b_proj / NSA-indexer dual-stream fork in
DeepseekMLAForwardMixin.forward_absorb_prepare so that on HIP the
indexer chain on alt overlaps not just with q_b_proj but also with the
gap-fill that follows on cur (bmm w_kc absorb + rotary_emb on q_pe/k_pe,
plus fused_qk_rope_cat_and_cache_mla on the gfx95 NSA tilelang path).

Two HIP-graph capture rules drive the layout (validated by the
microbenchmark in SGLang-benchmarks/tools/glm5_proposalA_v3_test.py
variant A_v4: -18.9 us/layer over the prior layout):

  1. Dispatch order picks the physical stream — the branch dispatched
     first at the fork keeps the predecessor stream (phys 0); the
     later-dispatched branch lands on a fresh aux stream (phys 4).
     We dispatch q_b_proj on cur FIRST and only afterwards enter
     `with stream(alt):` for the indexer.
  2. `alt.wait_stream(cur)` snapshots cur's state at call time. Since
     the indexer needs only q_lora (phase1 output), placing wait_stream
     BEFORE q_b_proj lets alt's heavy indexer chain start the instant
     phase1 completes — in parallel with cur's q_b_proj plus gap-fill,
     instead of waiting for q_b_proj first.

The `cur.wait_stream(alt)` join is moved past rotary_emb so cur's
gap-fill chain overlaps with alt's indexer.

CUDA / MUSA / NPU paths are gated to keep the original PR sgl-project#23562 layout
(byte-identical) — these were not validated under the new schedule.

Drives `overlap_indexer_with_gap_fill` flag used by sub-optimization (3).

==============================================================================
3. perf(rocm-nsa): pull fused_qk_rope_cat_and_cache_mla into the dual-stream
   window, and skip the redundant CatArrayBatchedCopy that follows attn_mqa
==============================================================================

For the gfx95 NSA tilelang fused-rope path, the
`fused_qk_rope_cat_and_cache_mla` kernel that normally runs in
`forward_absorb_core` is moved into `forward_absorb_prepare` so it runs
on cur inside the dual-stream window — overlapping with the alt
indexer instead of running serially after the join. The result is
forwarded from prepare to core via a new optional `fused_qk_kv_cache`
return field; core falls back to the original inline computation when
the prepare-side fast path was not taken (non-capture, non-decode, or
non-HIP).

In addition, `forward_absorb_core` now passes the already-concatenated
`q_cat` directly to `attn_mqa` with `q_rope=None` on the decode path
(prefill keeps the split form because `nsa_backend.forward_extend`
asserts `q_rope is not None`). On the receiving side,
`nsa_backend.forward_decode` is updated to track `q_all` explicitly:

  - When caller passes split q_nope / q_rope (CUDA / non-HIP paths or
    non-decode HIP), q_all is initialized to None and each impl block
    re-cats as before — byte-identical to the pre-patch behavior.
  - When caller passes q_rope=None on HIP decode, q_all is set to a
    zero-copy `q.contiguous().view(...)` of `q_cat` and each impl block
    skips the otherwise-redundant `concat_mla_absorb_q_general` call.

The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
backends always re-cat (preserves prior behavior bit-exactly).

This eliminates the CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel
that previously fired once per layer per decode step (~5 us/layer)
between fused_qk_rope_cat and main_kernel on ROCm tilelang traces:
390 invocations → 0 in DualStream0429_v2 trace.

==============================================================================
Validation
==============================================================================

  * MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA tilelang decode:
      - Layer latency: ~397 us → ~324 us  (-73 us / -18.4%)
      - 8k1k conc4 TPOT: 24.48 ms median (output throughput 117 tok/s)
      - GSM8K 1200q: 0.953  (PR sgl-project#23562 baseline 0.951)
  * trace: results/.../GLM-5.1-FP8-prof-DualStream0429_v2/
            prof_in8192_out1024_conc4_p8/*-TP-0-DECODE.trace.json.gz
  * Stacks on top of sgl-project#23562 (preshuffled paged MQA + page_size=64) and
    requires aiter PR ROCm/aiter#2879 (preshuffle layout in indexer
    k-cache kernels).

==============================================================================
Files
==============================================================================

  * deepseek_v2.py         (+5 -2)  alt_stream gate + routed_scaling guard
  * forward_mla.py        (+212 -73)  A_v4 layout + fused pull-up + cat-skip
                                      plumbing, HIP-only via `_is_hip` gate
  * nsa_backend.py        (+15 -4)  q_all tracking + cat-skip, HIP-only
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 29, 2026
…ual-stream

This commit lands two HIP-only optimizations on top of PR sgl-project#23562:

1. Cat-skip in nsa_backend.forward_decode (default ON, ~2.6 us / layer)
2. A_v4 NSA dual-stream layout (gated OFF by default — regresses on MI355X)

Validated on MI355X TP=8 GLM-5.1-FP8 (8k1k conc4):

  Variant                                 Median TPOT     Δ vs Thomas
  ---------------------------------------------------------------------
  Thomas (PR sgl-project#23562 only)                   21.21 ms        baseline
  This commit, default (cat-skip on,
    dual-stream off)                        20.48 ms        −3.4% (faster)
  This commit + SGLANG_ENABLE_HIP_DUAL_STREAM=1
    + --disable-shared-experts-fusion       24.45 ms        +15.3% (regression)

==============================================================================
1. Cat-skip optimization (default ON, HIP-only)
==============================================================================

In the NSA TileLang fused-rope decode path, fused_qk_rope_cat_and_cache_mla
produces a contiguous `q_cat` tensor of shape (M, num_heads, kv_lora_rank +
qk_rope_head_dim). The pre-patch flow then sliced q_cat into (q_nope_fused,
q_pe_fused) and passed them as separate args to attn_mqa, which causes
nsa_backend.forward_decode to call concat_mla_absorb_q_general(q_nope, q_rope)
to rebuild q_all. On ROCm that fallback hits torch.cat → CatArrayBatchedCopy,
producing a tensor that is byte-identical to the q_cat we already have.

forward_absorb_core now passes q_cat directly to attn_mqa with q_rope=None on
the decode path (prefill keeps the split form because forward_extend asserts
q_rope is not None). nsa_backend.forward_decode is updated to track q_all
explicitly:

  - When caller passes split q_nope / q_rope, q_all=None and each impl block
    re-cats as before — byte-identical to pre-patch behavior.
  - When caller passes q_rope=None on HIP decode, q_all is set to a zero-copy
    `q.contiguous().view(...)` and the cat is skipped.

The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP backends
always re-cat (preserves CUDA / MUSA behavior bit-exactly).

Effect: CatArrayBatchedCopy<OpaqueType<1u>, ...> kernel that previously fired
once per layer per decode step disappears from ROCm tilelang traces.

==============================================================================
2. A_v4 dual-stream layout (opt-in via SGLANG_ENABLE_HIP_DUAL_STREAM=1)
==============================================================================

forward_absorb_prepare gains a HIP-only A_v4 dual-stream layout that overlaps
the NSA indexer chain on alt with [q_b_proj + bmm w_kc + fused_qk_rope_cat]
on cur. Two HIP-graph capture rules drive the layout:

  1. Dispatch order picks the physical stream — the branch dispatched first
     keeps the predecessor stream (phys 0); the later-dispatched branch lands
     on a fresh aux stream (phys 4). q_b_proj is dispatched on cur FIRST,
     then `with stream(alt):` for the indexer.
  2. alt.wait_stream(cur) is placed BEFORE q_b_proj. Indexer needs only
     q_lora (phase1 output), not q_b_proj's q, so alt's heavy indexer chain
     can start the moment phase1 completes — in parallel with cur's q_b_proj
     plus gap-fill.

The cur.wait_stream(alt) join is moved past rotary_emb so cur's gap-fill
chain overlaps with alt's indexer. fused_qk_rope_cat_and_cache_mla is also
pulled from forward_absorb_core into prepare's dual-stream window, with the
result forwarded via a new optional fused_qk_kv_cache return field.

CUDA / MUSA / NPU paths take the original q_b_proj ∥ NSA-indexer layout from
PR sgl-project#23562 base (byte-identical) — the new layout was not validated on those
platforms.

Why opt-in: on MI355X the layout regresses ~30 us / layer due to three
contention sources:

  - HBM bandwidth contention: indexer's memory-bound kernels lose 0.5-2.4 us
    each when sharing HBM with cur GEMMs (+8 us total).
  - Compute-unit split: scheduler partitions 256 CUs across concurrent
    kernels, slowing both compute-bound kernels (+5 us total).
  - HIP-graph AllReduce slowdown: aiter::cross_device_reduce_1stage takes
    23 us under dual-stream graph capture vs 9.5 us single-stream — same
    kernel, same TP=8 topology. Likely caused by the AR's first-stage peer
    fence having to drain alt's KV-cache writes too. ~+26 us / layer (2 ARs).

Theoretical A_v4 saving (gap-fill ∥ indexer ≈ −10 us / layer) is dwarfed
by these costs. The layout is preserved behind SGLANG_ENABLE_HIP_DUAL_STREAM
for future ROCm releases that may fix the AR fence cost.

To enable for testing:

  SGLANG_ENABLE_HIP_DUAL_STREAM=1 ./GLM.sh --dual-stream-rocm ...
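A minimal sketch of the opt-in gate (it reads the SGLANG_ENABLE_HIP_DUAL_STREAM variable the commit adds; the helper name is illustrative, not the code in environ.py):

```python
import os

# Opt-in gate for the A_v4 dual-stream layout.
# Default off: the layout currently regresses on MI355X.
def hip_dual_stream_enabled() -> bool:
    return os.environ.get("SGLANG_ENABLE_HIP_DUAL_STREAM", "0") == "1"
```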

==============================================================================
Files changed
==============================================================================

  environ.py            (+8)   New env var SGLANG_ENABLE_HIP_DUAL_STREAM
  deepseek_v2.py        (+15 -2)
                              alt_stream gate now requires _is_hip + env var.
                              forward_normal_dual_stream's routed_scaling
                              multiply also adds `not _use_aiter` (aiter's
                              biased_grouped_topk already fuses the scaling).
  forward_mla.py        (+212 -73)
                              A_v4 layout in forward_absorb_prepare (gated on
                              _is_hip; degrades to serial when alt_stream is
                              None). fused_qk_rope_cat pull-up + q_rope=None
                              cat-skip plumbing in forward_absorb_core.
  nsa_backend.py        (+15 -4)
                              q_all tracking + cat-skip in forward_decode.
                              HIP-only — non-HIP always re-cats.

Stacks on top of PR sgl-project#23562 (preshuffled paged MQA + page_size=64) and
requires aiter PR ROCm/aiter#2879 (preshuffle layout in indexer k-cache
kernels).

Detailed regression analysis:
  ~/SGLang-benchmarks/tmp/dual_stream_regression_analysis.md
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 29, 2026
…ual-stream

@1am9trash 1am9trash changed the title [AMD][DO-NOT-MERGE] Enable preshuffle paged MQA and page_size=64 for NSA indexer [AMD] Enable preshuffle paged MQA and page_size=64 for NSA indexer May 4, 2026
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request May 6, 2026
GLM-5 NSA TileLang decode on ROCm dispatches a `CatArrayBatchedCopy` kernel
once per layer per decode step that rebuilds an already-existing tensor.
This is a strict-improvement bug fix: ~2.6 us / layer saved, 0 changes for
non-HIP backends.

==============================================================================
Root cause
==============================================================================

For the NSA TileLang fused-rope decode path (`_use_aiter_gfx95 + nsa +
nsa_decode_backend == "tilelang"`), `forward_absorb_core` calls
`fused_qk_rope_cat_and_cache_mla` which produces a contiguous q_cat tensor
of shape (M, num_heads, kv_lora_rank + qk_rope_head_dim). The pre-patch
flow then sliced q_cat into q_nope_fused / q_pe_fused and passed them as
separate args to attn_mqa.

attn_mqa -> NSABackend.forward_decode then takes the if-branch (q_rope
is not None), views the slices, and for tilelang / flashmla_sparse /
flashmla_kv / aiter decode impls calls
`concat_mla_absorb_q_general(q_nope, q_rope)` to rebuild q_all. On ROCm,
that helper falls back to `torch.cat([q_nope, q_rope], dim=-1)`, which
allocates a fresh contiguous tensor and dispatches a copy kernel. The
result is byte-identical to the q_cat we already had — the cat is pure
overhead.

==============================================================================
Fix
==============================================================================

(1) `forward_absorb_core` now passes q_cat directly to attn_mqa with
    q_rope=None on the decode path. Prefill (forward_extend) keeps the
    split form because `nsa_backend.forward_extend` asserts
    `q_rope is not None`.

(2) `nsa_backend.forward_decode` is updated to track q_all explicitly:

    - When the caller passes split q_nope / q_rope, q_all is initialized
      to None and each impl block re-cats as before (byte-identical to
      pre-patch behavior).
    - When the caller passes q_rope=None on HIP, q_all is set to a
      zero-copy `q.contiguous().view(...)` and the cat is skipped.

    The cat-skip is gated `if q_all is None or not _is_hip` so non-HIP
    backends always re-cat (preserves CUDA / MUSA paths bit-exactly).
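The byte-identity and zero-copy claims above can be checked on CPU. A minimal sketch with illustrative dimensions (512/64 stand in for kv_lora_rank / qk_rope_head_dim):

```python
import torch

M, H, KV_LORA, ROPE = 4, 8, 512, 64
q_cat = torch.randn(M, H, KV_LORA + ROPE)  # contiguous, like the fused output

# Pre-patch flow: slice, then re-cat. torch.cat allocates a fresh tensor and
# dispatches a copy kernel, yet the result is byte-identical to q_cat.
q_nope, q_rope = q_cat[..., :KV_LORA], q_cat[..., KV_LORA:]
rebuilt = torch.cat([q_nope, q_rope], dim=-1)
assert torch.equal(rebuilt, q_cat)
assert rebuilt.data_ptr() != q_cat.data_ptr()  # fresh allocation

# Post-patch flow: q_cat is already contiguous, so .contiguous() is a no-op
# and the view shares storage -- no copy kernel is dispatched.
q_all = q_cat.contiguous().view(M, H, KV_LORA + ROPE)
assert q_all.data_ptr() == q_cat.data_ptr()    # zero-copy
```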

==============================================================================
Validation
==============================================================================

MI355X TP=8 GLM-5.1-FP8 fp8 KV cache, NSA TileLang decode (on top of
PR sgl-project#23562 + aiter PR ROCm/aiter#2879):

   scenario              | before    | after     | TPOT  Δ
   --------------------- | --------- | --------- | --------
   8k1k conc4   TPOT     | 21.21 ms  | 20.76 ms  | -2.17%
   8k1k conc8   TPOT     | 25.28 ms  | 24.82 ms  | -1.82%
   8k1k conc16  TPOT     | 30.79 ms  | 30.33 ms  | -1.49%
   8k1k conc32  TPOT     | 42.92 ms  | 42.46 ms  | -1.07%
   8k1k conc64  TPOT     | 61.79 ms  | 61.33 ms  | -0.74%
   1k1k conc4   TPOT     | 18.79 ms  | 18.33 ms  | -2.45%
   1k1k conc8   TPOT     | 21.14 ms  | 20.66 ms  | -2.27%
   1k1k conc16  TPOT     | 23.63 ms  | 23.15 ms  | -2.03%
   1k1k conc32  TPOT     | 29.19 ms  | 28.69 ms  | -1.71%
   1k1k conc64  TPOT     | 35.02 ms  | 34.60 ms  | -1.20%

Output throughput improves by the same percentage on every scenario.
Cat-skip's absolute ~2.6 us / layer benefit is constant; the relative
gain is highest at small batch + short prompt (where total layer time is
smallest) and decays with batch size.

GSM8K accuracy: 0.942 vs 0.951 baseline (within run-to-run variance
observed across multiple runs of the same config: 0.946-0.953).

==============================================================================
Files
==============================================================================

  forward_mla.py    (+50 -16)  forward_absorb_core:_skip_rope_for_nsa_tilelang_fused
                                branch passes q_cat with q_rope=None for decode.
  nsa_backend.py    (+12  -4)  forward_decode tracks q_all and skips cat on HIP
                                when caller already provided concatenated q.